#!/usr/bin/env python3
"""
Scan ONLY RESULTS_FOLDER, find subfolders containing `scores.json`, group them by
model key in the folder name (e.g., 'llama', 'qwen', 'mistral'), then split each
model's runs into subgroups by run name: 'base' (default), 'sleeper', and 'mask'.

For each (model, subgroup), aggregate scores across all runs in that subgroup and
produce a single violin chart PDF:

  - Colors: HONEST = blue, DECEPTIVE = red (others = gray)
  - X-axis shows ONLY dataset short codes (no HONEST/DECEPTIVE in labels)
  - Short codes: HP-C, HP-KR, CG, ST, ID, IT, MASK
  - X tick labels rotated 45°
  - Height squeezed to ~half (figsize height = 3)
  - The same PDF is saved into EVERY run folder in that subgroup, named:
        violins_<model>__<subgroup>.pdf

The script never scans outside RESULTS_FOLDER.
"""

from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Any, Optional

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import numpy as np

# ---------------------------- Config ---------------------------------

# Substrings looked up (case-insensitively) in each run-folder name to decide
# which model a run belongs to. Extend if your run-folder names include other
# model tags.
MODEL_KEYS = ["llama", "qwen", "mistral", "alpaca"]

# Display names used in plot titles, keyed by model key. A key missing here
# falls back to its title-cased form when the title is built in main().
MODEL_PRETTY = {
    "llama": "Llama",
    "qwen": "Qwen",
    "mistral": "Mistral",
    "alpaca": "Alpaca",
}

# Root folder that contains all result subfolders (relative to CWD or absolute)
RESULTS_FOLDER = "results_mean_20250923"

# Violin face colors per (upper-cased) label; any other label uses FALLBACK_COLOR
COLOR_MAP = {"HONEST": "blue", "DECEPTIVE": "red"}
FALLBACK_COLOR = "gray"

# Fixed left-to-right order of dataset short codes on the x-axis. Datasets not
# listed here are appended after these in first-seen order.
ORDERED_DATASETS = ["HP-C", "HP-KR", "CG", "ST", "ID", "IT", "MASK"]

# --------------------------- Utilities --------------------------------

def results_root() -> Path:
    """Return the absolute, resolved path of RESULTS_FOLDER.

    RESULTS_FOLDER may be absolute or relative to the current working
    directory. The resolved location is printed for traceability.

    Returns:
        The resolved results directory.

    Raises:
        SystemExit: with exit code 1 when the resolved path is not an
            existing directory.
    """
    # resolve() anchors a relative path at the CWD and leaves an absolute one
    # absolute, so no explicit is_absolute() branching is needed.
    root = Path(RESULTS_FOLDER).resolve()
    print(f"[INFO] Using results root: {root}")
    # is_dir() rather than exists(): a regular file with that name is not a
    # usable results root and must be rejected too.
    if not root.is_dir():
        print(f"[ERROR] Expected a 'results' folder at: {root}", file=sys.stderr)
        sys.exit(1)
    return root

def find_run_dirs(root: Path) -> list[Path]:
    """Return the sorted, de-duplicated directories under root holding a scores.json.

    Each candidate directory is resolved first; anything that ends up outside
    ``root`` after resolution is dropped so the scan never leaves the results
    folder.
    """
    found: set[Path] = set()
    for scores_file in root.rglob("scores.json"):
        if not scores_file.is_file():
            continue
        run_dir = scores_file.parent.resolve()
        # Safety: keep only directories that are still inside root
        if run_dir.is_relative_to(root):
            found.add(run_dir)
    return sorted(found)

def infer_model_key(folder_name: str, model_keys: list[str]) -> Optional[str]:
    """Infer the model key from a folder name (case-insensitive substring match).

    Args:
        folder_name: Name of the run folder.
        model_keys: Candidate keys to look for inside ``folder_name``.

    Returns:
        The longest matching key (the earliest one in ``model_keys`` on a
        length tie), or ``None`` when no key occurs in the name.
    """
    name = folder_name.lower()
    matches = [k for k in model_keys if k.lower() in name]
    # max() returns the first maximal element, i.e. the same winner a stable
    # descending sort by length would put first.
    return max(matches, key=len) if matches else None

def subgroup_from_run_name(folder_name: str) -> str:
    """Return 'mask', 'sleeper' or 'base' depending on the run folder name.

    Matching is a case-insensitive substring test; 'mask' takes precedence
    over 'sleeper', and 'base' is the default when neither occurs.
    """
    lowered = folder_name.lower()
    for tag in ("mask", "sleeper"):
        if tag in lowered:
            return tag
    return "base"

def load_scores_json(scores_path: Path) -> dict[str, Any]:
    """Parse and return the JSON document stored at ``scores_path``.

    The file is read as UTF-8 regardless of the platform's default encoding.
    I/O and JSON decoding errors propagate to the caller.
    """
    with scores_path.open("r", encoding="utf-8") as f:
        return json.load(f)

def load_control_scores(run_dir: Path) -> Optional[list[float]]:
    """Load per-dialogue mean control scores used to derive the FPR threshold.

    Reads ``control_scores.json`` in ``run_dir`` and looks for the
    ``alpaca_XYZ__plain`` entry. Each non-empty list under its ``scores`` key
    is reduced to its mean; entries that are not lists, or are empty, are
    skipped.

    Args:
        run_dir: Run folder that may contain ``control_scores.json``.

    Returns:
        The list of per-dialogue means (possibly empty), or ``None`` when the
        file is missing, lacks the control entry, or cannot be read. Read
        failures are reported as a warning on stderr.
    """
    control_path = run_dir / "control_scores.json"
    if not control_path.exists():
        return None
    try:
        # JSON is read as UTF-8 regardless of the platform default encoding.
        with control_path.open("r", encoding="utf-8") as f:
            control_data = json.load(f)
        # Look for alpaca_XYZ__plain control dataset
        if "alpaca_XYZ__plain" in control_data:
            control_scores = control_data["alpaca_XYZ__plain"]["scores"]
            # One value per dialogue: the mean of that dialogue's scores,
            # converted to a plain float to match the annotated return type.
            return [
                float(np.mean(dialogue_scores))
                for dialogue_scores in control_scores
                if isinstance(dialogue_scores, list) and len(dialogue_scores) > 0
            ]
    except Exception as e:
        # Control scores are optional; any failure here must not abort plotting.
        print(f"[WARN] Could not load control scores from {control_path}: {e}", file=sys.stderr)
    return None

def calculate_fpr_threshold(control_scores: list[float], fpr: float = 0.01) -> float:
    """Return the score at the (1 - fpr) quantile of the control scores.

    NaN entries in ``control_scores`` are ignored by the quantile computation.
    The default ``fpr`` of 0.01 corresponds to a 1% false-positive rate.
    """
    quantile_level = 1 - fpr
    threshold = np.nanquantile(control_scores, quantile_level)
    return float(threshold)

def abbrev_dataset(name: str) -> str:
    """
    Map dataset names to short codes for the x-axis.
      Sleeper Agents            -> ST
      Insider Trading           -> IT
      Instruct_Dishonesty       -> ID
      Harm-Refusal + Choice     -> HP-C
      Harm-Refusal + Know-Rep   -> HP-KR
      Harm-Refusal (other)      -> HP
      Convincing Game           -> CG
      MASK                      -> MASK
    Matching is case-insensitive and treats '_', '-' and any run of whitespace
    as a single space. Rules are tried in the order below; the first hit wins.
    Falls back to the original name if nothing matches.
    """
    s = str(name).strip().lower()
    # split()/join collapses separator runs of any length, so inputs such as
    # "Instruct - Dishonesty" normalize to "instruct dishonesty".
    s_norm = " ".join(s.replace("_", " ").replace("-", " ").split())

    # MASK first
    if "mask" in s_norm:
        return "MASK"

    # Sleeper agents
    if "sleeper" in s_norm:
        return "ST"

    # Insider Trading
    if "insider" in s_norm or "trading" in s_norm:
        return "IT"

    # Instruct_Dishonesty
    if ("instruct" in s_norm and "decept" in s_norm) or "instruct dishonesty" in s_norm:
        return "ID"

    # Harm-Refusal + Choice / Knowledge-Report
    # ("repe_honesty" is tested on the un-normalized name: it contains '_')
    if "harm" in s_norm or "refusal" in s_norm or "repe_honesty" in s:
        if "choice" in s_norm or s_norm.endswith(" c") or " hp c" in s_norm:
            return "HP-C"
        if "knowledge" in s_norm or "report" in s_norm or s_norm.endswith(" kr") or " hp kr" in s_norm:
            return "HP-KR"
        # Fallback generic HP
        return "HP"

    # Convincing Game ("convinc" also covers "convincing" / "convincing game")
    if "convinc" in s_norm or s_norm == "cg":
        return "CG"

    # Soft-Trigger (if used elsewhere)
    if "soft" in s_norm or "trigger" in s_norm:
        return "ST"

    return str(name)

def flatten_dataset(payload: dict[str, Any]) -> list[tuple[str, float]]:
    """Convert {"labels":[...], "scores":[[...], ...]} into (label, score) rows.

    Labels and score groups are paired positionally; surplus entries on either
    side are ignored. Labels are stripped and upper-cased. Groups that are not
    lists, and individual scores that cannot be converted to ``float``, are
    skipped.

    Args:
        payload: One dataset entry from a scores file.

    Returns:
        A list of ``(LABEL, score)`` tuples; empty when ``labels`` or
        ``scores`` is missing, null, or not a list.
    """
    labels = payload.get("labels")
    groups = payload.get("scores")
    # A JSON null (or any non-sequence) for either key yields no rows instead
    # of raising inside zip().
    if not isinstance(labels, (list, tuple)) or not isinstance(groups, (list, tuple)):
        return []

    rows: list[tuple[str, float]] = []
    for lbl, group in zip(labels, groups):
        if not isinstance(group, list):
            continue
        norm_lbl = str(lbl).strip().upper()
        for s in group:
            try:
                rows.append((norm_lbl, float(s)))
            except (TypeError, ValueError, OverflowError):
                # Non-numeric entry (e.g. null or free text): drop it.
                continue
    return rows

def aggregate_for_group(run_dirs: list[Path]) -> pd.DataFrame:
    """Collect every score of the given runs into one long-format DataFrame.

    The result has the columns ``run`` (folder name), ``dataset`` (short
    code), ``label`` and ``score``. Runs whose ``scores.json`` cannot be read,
    or whose top level is not a JSON object, are reported on stderr and
    skipped.
    """
    columns = ["run", "dataset", "label", "score"]
    records = []
    for run_dir in run_dirs:
        scores_file = run_dir / "scores.json"
        try:
            content = load_scores_json(scores_file)
        except Exception as e:
            print(f"[WARN] Could not read {scores_file}: {e}", file=sys.stderr)
            continue
        if not isinstance(content, dict):
            print(f"[WARN] Unexpected structure in {scores_file}", file=sys.stderr)
            continue

        for dataset_name, payload in content.items():
            # A non-dict entry is treated as an empty dataset
            pairs = flatten_dataset(payload if isinstance(payload, dict) else {})
            short_code = abbrev_dataset(dataset_name)
            records.extend(
                {"run": run_dir.name, "dataset": short_code, "label": lbl, "score": val}
                for lbl, val in pairs
            )

    return pd.DataFrame.from_records(records, columns=columns)

def ordered_labels(all_labels: list[str]) -> list[str]:
    """Return the distinct labels: HONEST, then DECEPTIVE (when present), then the rest A-Z."""
    distinct = set(all_labels)
    head = [name for name in ("HONEST", "DECEPTIVE") if name in distinct]
    tail = sorted(distinct.difference(head))
    return head + tail

def plot_violins_pdf(df: pd.DataFrame, title: str, save_paths: list[Path], run_dirs: list[Path]) -> None:
    """
    Draw one violin per (dataset, label) pair and save the figure as PDF.

    Violins are grouped by dataset; x-ticks sit under the middle of each group
    and show only the dataset short code (no label text).
    Colors: HONEST blue, DECEPTIVE red, others gray. When the title contains
    "alpaca", every violin is gray and datasets whose code contains "alpaca"
    are placed on the right.
    A dashed 1% FPR threshold line is added when control scores are found in
    one of ``run_dirs`` (the first folder that provides them is used).
    Dataset/label combinations without any non-NaN score are skipped, and a
    dataset left with no violin at all gets no tick.

    Args:
        df: Long-format scores with columns ``dataset``, ``label``, ``score``.
        title: Figure title; also inspected for "alpaca" (case-insensitive).
        save_paths: Target paths; each is saved with its suffix set to ".pdf".
        run_dirs: Run folders searched for control scores.

    Returns:
        None. Nothing is drawn when ``df`` is empty or holds no usable score.
        A failure to write one target is reported on stderr and does not stop
        the remaining targets.
    """
    if df.empty:
        print("[WARN] No data to plot for this subgroup.", file=sys.stderr)
        return

    # Calculate 1% FPR threshold from control scores (use first run_dir that has control scores)
    threshold_1pct_fpr = None
    for run_dir in run_dirs:
        control_scores = load_control_scores(run_dir)
        if control_scores:
            threshold_1pct_fpr = calculate_fpr_threshold(control_scores, fpr=0.01)
            print(f"[INFO] 1% FPR threshold: {threshold_1pct_fpr:.4f}")
            break

    if threshold_1pct_fpr is None:
        print("[WARN] Could not calculate 1% FPR threshold; no control scores found", file=sys.stderr)

    # Distinct dataset codes in first-seen order
    present = list(dict.fromkeys(df["dataset"].tolist()))

    # Check if this is alpaca model and reorder datasets to put alpaca on the right
    is_alpaca_model = "alpaca" in title.lower()
    if is_alpaca_model:
        alpaca_datasets = [ds for ds in present if "alpaca" in ds.lower()]
        other_datasets = [ds for ds in present if "alpaca" not in ds.lower()]
        # Use fixed order for non-alpaca datasets, then add alpaca datasets
        ordered_others = (
            [d for d in ORDERED_DATASETS if d in other_datasets]
            + [d for d in other_datasets if d not in ORDERED_DATASETS]
        )
        candidate_datasets = ordered_others + alpaca_datasets
    else:
        # Use fixed order, append any others at the end in their first-seen order
        candidate_datasets = (
            [d for d in ORDERED_DATASETS if d in present]
            + [d for d in present if d not in ORDERED_DATASETS]
        )

    labels = ordered_labels(df["label"].tolist())  # HONEST first, DECEPTIVE second if present

    # Build the violins dataset by dataset so each dataset's bodies are
    # contiguous, recording the tick center of every dataset that actually
    # contributes at least one violin.
    groups: list[list[float]] = []
    body_labels: list[str] = []  # label for each violin body
    datasets: list[str] = []     # datasets that received at least one violin
    centers: list[float] = []    # x tick position for each entry in `datasets`

    for ds in candidate_datasets:
        sub = df[df["dataset"] == ds]
        first_pos = len(groups)
        for lbl in labels:
            vals = sub.loc[sub["label"] == lbl, "score"].dropna().tolist()
            if not vals:
                continue
            groups.append(vals)
            body_labels.append(lbl)
        n_bodies = len(groups) - first_pos
        if n_bodies == 0:
            # Only NaN scores: no violin, hence no tick for this dataset.
            continue
        datasets.append(ds)
        # Mean of the consecutive positions first_pos .. first_pos + n_bodies - 1
        centers.append(first_pos + (n_bodies - 1) / 2)

    if not groups:
        print("[WARN] All dataset/label groups empty after filtering; skipping plot.", file=sys.stderr)
        return

    positions = list(range(len(groups)))  # simple 0..N-1

    # Plot (height squeezed to ~half)
    fig = plt.figure(figsize=(max(10, len(datasets) * 1.6), 3))
    try:
        ax = fig.add_subplot(111)

        vp = ax.violinplot(
            groups,
            positions=positions,
            showmeans=True,
            showmedians=False,
            showextrema=True,
        )

        # Color bodies per label (DECEPTIVE=red, HONEST=blue, else gray)
        # For alpaca model, use grey for all
        for body, lbl in zip(vp["bodies"], body_labels):
            body.set_alpha(0.7)
            if is_alpaca_model:
                body.set_facecolor("gray")
            else:
                body.set_facecolor(COLOR_MAP.get(lbl, FALLBACK_COLOR))
            body.set_edgecolor("black")
            body.set_linewidth(0.6)

        for part_name in ("cmeans", "cmins", "cmaxes"):
            if part_name in vp:
                vp[part_name].set_color("black")
                vp[part_name].set_linewidth(0.8)

        # Draw 1% FPR threshold line if available
        if threshold_1pct_fpr is not None:
            ax.axhline(y=threshold_1pct_fpr, color='red', linestyle='--', linewidth=1.5, alpha=0.8, label='1% FPR Threshold')

        # Legend only for labels present
        if is_alpaca_model:
            # For alpaca, just show "Alpaca" in gray
            handles = [Patch(facecolor="gray", edgecolor="black", label="Alpaca")]
        else:
            drawn = set(body_labels)
            legend_labels = [cand for cand in ("HONEST", "DECEPTIVE") if cand in drawn]
            legend_labels.extend(sorted(drawn - set(legend_labels)))
            handles = [
                Patch(facecolor=COLOR_MAP.get(label, FALLBACK_COLOR), edgecolor="black", label=label)
                for label in legend_labels
            ]

        # Add threshold line to legend if present
        if threshold_1pct_fpr is not None:
            from matplotlib.lines import Line2D
            handles.append(Line2D([0], [0], color='red', linestyle='--', linewidth=1.5, label='1% FPR Threshold'))

        if handles:
            ax.legend(handles=handles, loc="best", frameon=True, fontsize=9)

        ax.set_title(title, fontsize=12)
        ax.set_ylabel("Score", fontsize=11)

        # X ticks at dataset centers; labels are just dataset short codes
        ax.set_xticks(centers)
        ax.set_xticklabels(datasets, rotation=45, ha="right", rotation_mode="anchor")
        ax.tick_params(axis="x", pad=6, labelsize=10)
        ax.grid(True, axis="y", linestyle=":", linewidth=0.8, alpha=0.6)

        fig.tight_layout()

        for path in save_paths:
            try:
                pdf_path = path.with_suffix(".pdf")
                fig.savefig(pdf_path, dpi=300, bbox_inches="tight")
                print(f"[OK] Saved: {pdf_path}")
            except Exception as e:
                # Best effort: one unwritable folder must not block the others.
                print(f"[WARN] Could not save {path}: {e}", file=sys.stderr)
    finally:
        # Release the figure even when plotting fails part-way through.
        plt.close(fig)

# ----------------------------- Main -----------------------------------

def _display_path(path: Path, root: Path) -> str:
    """Return ``path`` relative to ``root`` for logging, or unchanged if it is not below ``root``."""
    try:
        return str(path.relative_to(root))
    except ValueError:
        return str(path)


def main():
    """Discover runs, group them by model and subgroup, and write one violin PDF per group.

    Steps:
      1. Resolve the results root and collect every folder with a scores.json.
      2. Assign each run to a model via its folder name; runs without a known
         model key are listed and skipped.
      3. Split each model's runs into 'base', 'sleeper' and 'mask' subgroups.
      4. Aggregate scores per subgroup (MASK data only in the 'mask' plot) and
         save the same PDF into every run folder of that subgroup.

    Exits with code 1 when no run folder is found and with code 2 when no run
    could be assigned to a model.
    """
    root = results_root()
    run_dirs = find_run_dirs(root)
    if not run_dirs:
        print("[ERROR] No run directories with scores.json found under the results folder.", file=sys.stderr)
        sys.exit(1)

    # Group by model key → then subgroup by run name
    by_model: dict[str, list[Path]] = {}
    skipped: list[Path] = []
    for rd in run_dirs:
        key = infer_model_key(rd.name, MODEL_KEYS)
        if key is None:
            skipped.append(rd)
        else:
            by_model.setdefault(key, []).append(rd)

    if skipped:
        print("[INFO] Runs with unknown model key (skipping):")
        for rd in skipped:
            print(f"  - {_display_path(rd, root)}")

    if not by_model:
        print("[ERROR] No model groups formed. Adjust MODEL_KEYS.", file=sys.stderr)
        sys.exit(2)

    # For each model, split into base/sleeper/mask subgroups by folder name
    for model_key, dirs in by_model.items():
        subgroups: dict[str, list[Path]] = {"base": [], "sleeper": [], "mask": []}
        for d in dirs:
            subgroups[subgroup_from_run_name(d.name)].append(d)

        for sg_name, sg_dirs in subgroups.items():
            if not sg_dirs:
                continue

            print(f"\n[GROUP] '{model_key}' • subgroup '{sg_name}' • {len(sg_dirs)} run(s)")
            for d in sg_dirs:
                print(f"  - {_display_path(d, root)}")

            df = aggregate_for_group(sg_dirs)

            # Keep MASK only in its own plot; exclude from others
            if sg_name == "mask":
                df = df[df["dataset"] == "MASK"]
            else:
                df = df[df["dataset"] != "MASK"]  # CG (Convincing Game) remains in base

            if df.empty:
                # Warnings go to stderr, like every other [WARN] in this script.
                print("[WARN] No scores after subgroup filtering; skipping this subgroup.", file=sys.stderr)
                continue

            pretty = MODEL_PRETTY.get(model_key.lower(), model_key.title())
            title = f"Scores by dataset - Model '{pretty}'"

            # Save same PDF into every run folder in this subgroup
            save_paths = [d / f"violins_{model_key}__{sg_name}" for d in sg_dirs]
            plot_violins_pdf(df, title, save_paths, sg_dirs)


if __name__ == "__main__":
    main()
